{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Document Classification\n",
    "This tutorial will show how to perform document classification in Tribuo, using a variety of different methods to extract features from the text. We'll use the venerable [20-newsgroups dataset](http://qwone.com/~jason/20Newsgroups/) where the task is to predict what newsgroup a particular post is from, though this tutorial would be equally applicable to any document classification task (including tasks like sentiment analysis). We're going to train a simple logistic regression with fixed hyperparameters using a variety of feature extraction methods. The aim is to show how to extract features from text rather than focusing on the performance, as using a more powerful model like XGBoost, or performing hyperparameter optimization on the logisitic regression will likely improve the performance of all the feature extraction techniques.\n",
    "\n",
    "# Setup\n",
    "\n",
    "You'll need a copy of the 20 newsgroups dataset, so first download and unpack it:\n",
    "\n",
    "```\n",
    "wget http://qwone.com/~jason/20Newsgroups/20news-bydate.tar.gz\n",
    "mkdir 20news\n",
    "cd 20news\n",
    "tar -zxf ../20news-bydate.tar.gz\n",
    "```\n",
    "\n",
    "This leaves you with two directories `20news-bydate-train` and `20news-bydate-test`, which contain the standard train and test split for this data.\n",
    "\n",
    "20 newsgroups comes in a fairly standard format, the dataset is represented by a set of directories where the directory name is the class label, and the directory contains a collection of documents with one document in each file. Each file is a single Usenet post. For the purposes of this tutorial, we'll use the subject and body of the post as the input text for classification.\n",
    "\n",
    "Here's an example:\n",
    "\n",
    "```\n",
    "$ ls 20news-bydate-train/\n",
    "alt.atheism/               comp.sys.mac.hardware/  rec.motorcycles/     sci.electronics/         talk.politics.guns/\n",
    "comp.graphics/             comp.windows.x/         rec.sport.baseball/  sci.med/                 talk.politics.mideast/\n",
    "comp.os.ms-windows.misc/   misc.forsale/           rec.sport.hockey/    sci.space/               talk.politics.misc/\n",
    "comp.sys.ibm.pc.hardware/  rec.autos/              sci.crypt/           soc.religion.christian/  talk.religion.misc/\n",
    "$ ls 20news-bydate-train/comp.graphics/\n",
    "37261  37949  38233  38270  38305  38344  38381  38417  38454  38489  38525  38562  38598  38633  38668  38703  38739\n",
    "37913  37950  38234  38271  38306  38346  38382  38418  38455  38490  38526  38563  38599  38634  38669  38704  38740\n",
    "37914  37951  38235  38272  38307  38347  38383  38420  38456  38491  38527  38564  38600  38635  38670  38705  38741\n",
    "37915  37952  38236  38273  38308  38348  38384  38421  38457  38492  38528  38565  38601  38636  38671  38706  38742\n",
    "...\n",
    "```\n",
    "\n",
    "As this is a pretty common format, Tribuo has a specific `DataSource` which can be used to read in this sort of data, `org.tribuo.data.text.DirectoryFileSource`.\n",
    "\n",
    "We're going to use the classification experiments jar, along with the ONNX jar which provides support for loading in contextual word embedding models like [BERT](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%jars ./tribuo-classification-experiments-4.3.0-SNAPSHOT-jar-with-dependencies.jar\n",
    "%jars ./tribuo-onnx-4.3.0-SNAPSHOT-jar-with-dependencies.jar"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll also need a selection of imports from the `org.tribuo.data.text` package, along with the usual imports from `org.tribuo` and `org.tribuo.classification` we use when working with classification tasks. We'll load in the BERT support from the `org.tribuo.interop.onnx.extractors` package. Tribuo's BERT support loads in models and tokenizers from [HuggingFace's Transformer](https://huggingface.co/transformers/) package, and can be easily extended to support non-BERT models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.util.Collections;\n",
    "import java.nio.file.Paths;\n",
    "import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;\n",
    "import com.oracle.labs.mlrg.olcut.util.Pair;\n",
    "import org.tribuo.*;\n",
    "import org.tribuo.data.text.*;\n",
    "import org.tribuo.data.text.impl.*;\n",
    "import org.tribuo.dataset.MinimumCardinalityDataset;\n",
    "import org.tribuo.classification.*;\n",
    "import org.tribuo.classification.evaluation.*;\n",
    "import org.tribuo.classification.sgd.linear.LinearSGDTrainer;\n",
    "import org.tribuo.classification.sgd.objectives.LogMulticlass;\n",
    "import org.tribuo.interop.onnx.extractors.BERTFeatureExtractor;\n",
    "import org.tribuo.math.optimisers.AdaGrad;\n",
    "import org.tribuo.transform.*;\n",
    "import org.tribuo.transform.transformations.IDFTransformation;\n",
    "import org.tribuo.util.tokens.universal.UniversalTokenizer;\n",
    "import org.tribuo.util.Util;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll instantiate a few classes that we'll use throughout this tutorial, the label factory, the evaluator and the paths to the train and test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "var labelFactory = new LabelFactory();\n",
    "var labelEvaluator = new LabelEvaluator();\n",
    "var trainPath = Paths.get(\".\",\"20news\",\"20news-bydate-train\");\n",
    "var testPath = Paths.get(\".\",\"20news\",\"20news-bydate-test\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extracting features from text\n",
    "Much of the work of machine learning is in presenting an appropriate representation of the data to the model. This is especially true when working with text data, as there is a plethora of approaches for converting text into the numbers that ML algorithms operate on. The `DirectoryFileSource` allows the user to choose the feature extraction, as it requires a `TextFeatureExtractor` which converts the `String` representing the input text into a Tribuo `Example`. We'll cover several different implementations of the `TextFeatureExtractor` interface in this tutorial, and we expect that users will implement it in their own classes to cope with specific feature extraction requirements.\n",
    "\n",
    "We'll start with the simplest approach, a \"bag of words\", where each document is represented by the counts of the words in that document. This means the feature space is equal to the number of words, and most documents only have a positive value for a small number of words (as most words don't appear in any given document). This is particularly well suited to Tribuo's sparse vector representation of examples, and this suitability for NLP tasks is the reason that Tribuo is designed this way. Of course, first we'll need to tell the extractor what a word is, and for this we use a `Tokenizer`. Tokenizers split up a `String` into a stream of tokens. Tribuo provides several basic tokenizers, and an interface for tokenization. We're going to use Tribuo's `UniversalTokenizer` which is descended from tokenizers developed at Sun Labs in the 90s, and used in a variety of Sun products since that time. First we'll use a strict bag of words where each feature takes the value `1` if that word is present in the document, and `0` otherwise. We'll use Tribuo's `BasicPipeline` which can convert `String`s into features, and pass it to the basic `TextFeatureExtractor` implementation, helpfully called `TextFeatureExtractorImpl`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "var tokenizer = new UniversalTokenizer();\n",
    "var bowPipeline = new BasicPipeline(tokenizer,1);\n",
    "var bowExtractor = new TextFeatureExtractorImpl<Label>(bowPipeline);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We're now almost ready to make our train and test data sources, and load in the data. The `DirectoryFileSource` also accepts an array of `DocumentPreprocessor`s which can be used to transform the text before feature extraction takes place. We're going to use a specific preprocessor (`NewsPreprocessor`) which standardises the 20 newsgroups data by stripping out the mail headers and returning only the subject and the body of the email. We'll also lowercase all the text using the `CasingPreprocessor` to slightly reduce the space we're working in. In general the preprocessors are dataset and task specific, which is why Tribuo doesn't ship with many implementations as in most cases users will need to write one from scratch for their specific task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "var newsProc = new NewsPreprocessor();\n",
    "var lowercase = new CasingPreprocessor(CasingPreprocessor.CasingOperation.LOWERCASE);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll make a helper function to load the data sources and create the datasets. We're also going to restrict the test dataset so it only contains valid examples, as 20 newsgroups has some test examples that share no words with the train examples (and so have no features we could use to make predictions with).\n",
    "\n",
    "Let's check our datasets and see if everything has loaded in correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bow training data size = 11314, number of features = 122024, number of classes = 20\n",
      "bow testing data size = 7532, number of features = 122024, number of classes = 20\n"
     ]
    }
   ],
   "source": [
    "public Pair<Dataset<Label>,Dataset<Label>> mkDatasets(String name, TextFeatureExtractor<Label> extractor) {\n",
    "    var trainSource = new DirectoryFileSource<>(trainPath,labelFactory,extractor,newsProc,lowercase);\n",
    "    var testSource = new DirectoryFileSource<>(testPath,labelFactory,extractor,newsProc,lowercase);\n",
    "    var trainDS = new MutableDataset<>(trainSource);\n",
    "    var testDS = new ImmutableDataset<>(testSource,trainDS.getFeatureIDMap(),trainDS.getOutputIDInfo(),true);\n",
    "    System.out.println(String.format(name + \" training data size = %d, number of features = %d, number of classes = %d\",trainDS.size(),trainDS.getFeatureMap().size(),trainDS.getOutputInfo().size()));\n",
    "    System.out.println(String.format(name + \" testing data size = %d, number of features = %d, number of classes = %d\",testDS.size(),testDS.getFeatureMap().size(),testDS.getOutputInfo().size()));\n",
    "    return new Pair<>(trainDS,testDS);\n",
    "}\n",
    "\n",
    "var bowPair = mkDatasets(\"bow\",bowExtractor);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We've loaded in 11,314 training documents containing 122,024 unique words and 7,532 test documents, each with the expected 20 classes.\n",
    "\n",
    "Now we're ready to train a model. Let's start with a simple logistic regression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training the model on BoW features took (00:00:09:601)\n",
      "\n",
      "Class                                n          tp          fn          fp      recall        prec          f1\n",
      "soc.religion.christian             398         352          46         110       0.884       0.762       0.819\n",
      "rec.autos                          396         344          52          63       0.869       0.845       0.857\n",
      "talk.religion.misc                 251         166          85         120       0.661       0.580       0.618\n",
      "comp.windows.x                     395         283         112          55       0.716       0.837       0.772\n",
      "rec.sport.baseball                 397         370          27          45       0.932       0.892       0.911\n",
      "comp.graphics                      389         293          96         143       0.753       0.672       0.710\n",
      "talk.politics.mideast              376         283          93          11       0.753       0.963       0.845\n",
      "comp.sys.ibm.pc.hardware           392         277         115         160       0.707       0.634       0.668\n",
      "sci.med                            396         323          73          43       0.816       0.883       0.848\n",
      "comp.os.ms-windows.misc            394         272         122          87       0.690       0.758       0.722\n",
      "sci.crypt                          396         349          47          23       0.881       0.938       0.909\n",
      "comp.sys.mac.hardware              385         283         102          96       0.735       0.747       0.741\n",
      "misc.forsale                       390         341          49          63       0.874       0.844       0.859\n",
      "rec.motorcycles                    398         364          34          23       0.915       0.941       0.927\n",
      "talk.politics.misc                 310         182         128          94       0.587       0.659       0.621\n",
      "sci.electronics                    393         272         121         135       0.692       0.668       0.680\n",
      "rec.sport.hockey                   399         367          32          24       0.920       0.939       0.929\n",
      "sci.space                          394         325          69          56       0.825       0.853       0.839\n",
      "alt.atheism                        319         243          76          75       0.762       0.764       0.763\n",
      "talk.politics.guns                 364         303          61         114       0.832       0.727       0.776\n",
      "Total                            7,532       5,992       1,540       1,540\n",
      "Accuracy                                                                         0.796\n",
      "Micro Average                                                                    0.796       0.796       0.796\n",
      "Macro Average                                                                    0.790       0.795       0.791\n",
      "Balanced Error Rate                                                              0.210\n"
     ]
    }
   ],
   "source": [
    "var lrTrainer = new LinearSGDTrainer(new LogMulticlass(),new AdaGrad(0.1,0.001),5,42);\n",
    "var bowStartTime = System.currentTimeMillis();\n",
    "var bowModel = lrTrainer.train(bowPair.getA());\n",
    "var bowEndTime = System.currentTimeMillis();\n",
    "System.out.println(\"Training the model on BoW features took \" + Util.formatDuration(bowStartTime,bowEndTime));\n",
    "System.out.println();\n",
    "var bowEval = labelEvaluator.evaluate(bowModel,bowPair.getB());\n",
    "System.out.println(bowEval);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We got a macro F1 score of 79.6%, which is a fairly good starting point and it's roughly what other linear models get on this task (e.g., scikit-learn's text classification tutorial gets 76.9% macro F1 when using a similar multinomial Naive Bayes model)."
   ]
  },
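  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough worked check (an illustration, not part of the evaluation code), the per-class columns in the table above follow the standard definitions of precision, recall and F1. Taking the `soc.religion.christian` row, with `tp = 352`, `fn = 46` and `fp = 110`:\n",
    "\n",
    "$$\n",
    "\\text{recall} = \\frac{tp}{tp + fn} = \\frac{352}{398} \\approx 0.884, \\qquad \\text{precision} = \\frac{tp}{tp + fp} = \\frac{352}{462} \\approx 0.762\n",
    "$$\n",
    "\n",
    "$$\n",
    "F_1 = \\frac{2 \\cdot \\text{precision} \\cdot \\text{recall}}{\\text{precision} + \\text{recall}} \\approx \\frac{2 \\cdot 0.762 \\cdot 0.884}{0.762 + 0.884} \\approx 0.819\n",
    "$$\n",
    "\n",
    "Accuracy is the total number of correct predictions divided by the number of test examples, $5992 / 7532 \\approx 0.796$, while the macro average is the unweighted mean of the per-class values, which is why the two figures differ slightly."
   ]
  },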
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Term counting\n",
    "This simple Bag of Words approach discards a lot of information about the documents, as we're ignoring how many times the word or n-gram appears in the document (also known in information retrieval circles as the Term Frequency or TF). Let's swap the `BasicPipeline` for a `TokenPipeline` which supports term counting via a constructor flag."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "unigram training data size = 11314, number of features = 122024, number of classes = 20\n",
      "unigram testing data size = 7532, number of features = 122024, number of classes = 20\n"
     ]
    }
   ],
   "source": [
    "var unigramPipeline = new TokenPipeline(tokenizer, 1, true);\n",
    "var unigramExtractor = new TextFeatureExtractorImpl<Label>(unigramPipeline);\n",
    "var unigramPair = mkDatasets(\"unigram\",unigramExtractor);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see the number of documents and number of features are still the same, all that's different is the feature values within each document. Let's build another logistic regression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training the model on Unigram features took (00:00:09:146)\n",
      "\n",
      "Class                                n          tp          fn          fp      recall        prec          f1\n",
      "soc.religion.christian             398         362          36          88       0.910       0.804       0.854\n",
      "rec.autos                          396         353          43          58       0.891       0.859       0.875\n",
      "talk.religion.misc                 251         148         103          97       0.590       0.604       0.597\n",
      "comp.windows.x                     395         295         100          54       0.747       0.845       0.793\n",
      "rec.sport.baseball                 397         356          41          49       0.897       0.879       0.888\n",
      "comp.graphics                      389         280         109         120       0.720       0.700       0.710\n",
      "talk.politics.mideast              376         310          66          29       0.824       0.914       0.867\n",
      "comp.sys.ibm.pc.hardware           392         266         126         133       0.679       0.667       0.673\n",
      "sci.med                            396         310          86          42       0.783       0.881       0.829\n",
      "comp.os.ms-windows.misc            394         241         153          82       0.612       0.746       0.672\n",
      "sci.crypt                          396         354          42          55       0.894       0.866       0.880\n",
      "comp.sys.mac.hardware              385         312          73         103       0.810       0.752       0.780\n",
      "misc.forsale                       390         343          47          69       0.879       0.833       0.855\n",
      "rec.motorcycles                    398         362          36          27       0.910       0.931       0.920\n",
      "talk.politics.misc                 310         171         139          90       0.552       0.655       0.599\n",
      "sci.electronics                    393         289         104         110       0.735       0.724       0.730\n",
      "rec.sport.hockey                   399         374          25          23       0.937       0.942       0.940\n",
      "sci.space                          394         342          52          57       0.868       0.857       0.863\n",
      "alt.atheism                        319         240          79          84       0.752       0.741       0.747\n",
      "talk.politics.guns                 364         314          50         140       0.863       0.692       0.768\n",
      "Total                            7,532       6,022       1,510       1,510\n",
      "Accuracy                                                                         0.800\n",
      "Micro Average                                                                    0.800       0.800       0.800\n",
      "Macro Average                                                                    0.793       0.795       0.792\n",
      "Balanced Error Rate                                                              0.207\n"
     ]
    }
   ],
   "source": [
    "var unigramStartTime = System.currentTimeMillis();\n",
    "var unigramModel = lrTrainer.train(unigramPair.getA());\n",
    "var unigramEndTime = System.currentTimeMillis();\n",
    "System.out.println(\"Training the model on Unigram features took \" + Util.formatDuration(unigramStartTime,unigramEndTime));\n",
    "System.out.println();\n",
    "var unigramEval = labelEvaluator.evaluate(unigramModel,unigramPair.getB());\n",
    "System.out.println(unigramEval);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the logistic regression trained on unigrams gets about 80% accuracy, pretty much the same as the BoW baseline, and takes about the same amount of time to run. Both of these make sense, as the term count isn't necessarily that useful in this particular dataset, and we didn't change the number of features overall or inside each example by using term counting.\n",
    "\n",
    "\n",
    "## N-grams as features\n",
    "Let's try a little more complicated feature extractor. The natural step from unigrams is to include word pairs (or bigrams) and count the occurrence of those. This allows us to get simple negations (e.g., \"not bad\" rather than \"not\" and \"bad\") along with places like \"New York\" rather than \"new\" and \"york\". In Tribuo this is as straightforward as telling the token pipeline we'd like bigrams."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bigram training data size = 11314, number of features = 1143035, number of classes = 20\n",
      "bigram testing data size = 7532, number of features = 1143035, number of classes = 20\n"
     ]
    }
   ],
   "source": [
    "var bigramPipeline = new TokenPipeline(tokenizer, 2, true);\n",
    "var bigramExtractor = new TextFeatureExtractorImpl<Label>(bigramPipeline);\n",
    "var bigramPair = mkDatasets(\"bigram\",bigramExtractor);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see the feature space has massively increased due to the presence of bigram features, we've now got 1.1 million features from the same 11,314 documents.\n",
    "\n",
    "Now to train another logistic regression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training the model on Bigram features took (00:00:43:790)\n",
      "\n",
      "Class                                n          tp          fn          fp      recall        prec          f1\n",
      "soc.religion.christian             398         331          67          57       0.832       0.853       0.842\n",
      "rec.autos                          396         326          70          55       0.823       0.856       0.839\n",
      "talk.religion.misc                 251         167          84         106       0.665       0.612       0.637\n",
      "comp.windows.x                     395         297          98          57       0.752       0.839       0.793\n",
      "rec.sport.baseball                 397         357          40          52       0.899       0.873       0.886\n",
      "comp.graphics                      389         304          85         196       0.781       0.608       0.684\n",
      "talk.politics.mideast              376         300          76          48       0.798       0.862       0.829\n",
      "comp.sys.ibm.pc.hardware           392         244         148         104       0.622       0.701       0.659\n",
      "sci.med                            396         298          98          66       0.753       0.819       0.784\n",
      "comp.os.ms-windows.misc            394         260         134          99       0.660       0.724       0.691\n",
      "sci.crypt                          396         327          69          37       0.826       0.898       0.861\n",
      "comp.sys.mac.hardware              385         320          65         162       0.831       0.664       0.738\n",
      "misc.forsale                       390         352          38         102       0.903       0.775       0.834\n",
      "rec.motorcycles                    398         359          39          39       0.902       0.902       0.902\n",
      "talk.politics.misc                 310         185         125          93       0.597       0.665       0.629\n",
      "sci.electronics                    393         253         140          90       0.644       0.738       0.688\n",
      "rec.sport.hockey                   399         370          29          30       0.927       0.925       0.926\n",
      "sci.space                          394         336          58          40       0.853       0.894       0.873\n",
      "alt.atheism                        319         225          94          65       0.705       0.776       0.739\n",
      "talk.politics.guns                 364         309          55         114       0.849       0.730       0.785\n",
      "Total                            7,532       5,920       1,612       1,612\n",
      "Accuracy                                                                         0.786\n",
      "Micro Average                                                                    0.786       0.786       0.786\n",
      "Macro Average                                                                    0.781       0.786       0.781\n",
      "Balanced Error Rate                                                              0.219\n"
     ]
    }
   ],
   "source": [
    "var bigramStartTime = System.currentTimeMillis();\n",
    "var bigramModel = lrTrainer.train(bigramPair.getA());\n",
    "var bigramEndTime = System.currentTimeMillis();\n",
    "System.out.println(\"Training the model on Bigram features took \" + Util.formatDuration(bigramStartTime,bigramEndTime));\n",
    "System.out.println();\n",
    "var bigramEval = labelEvaluator.evaluate(bigramModel,bigramPair.getB());\n",
    "System.out.println(bigramEval);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our performance decreased a little when using bigrams to 78%, and the runtime increased from 10s to 48s. This is because despite there being more information in the features, there are also many, many more features making it easier to confuse this simple linear model plus each example takes longer to process due to the greatly increased number of features. We could look at using a more complex model like boosted trees to exploit this additional information which may increase the performance back above our baseline. We could further increase number of n-gram features but we'll start to see diminishing returns even with more powerful models as the dimensionality of the feature space increases without a commensurate increase in training data.\n",
    "\n",
    "## TFIDF vectors\n",
    "\n",
    "One other factor is that the count of some words isn't usually that helpful, as most documents include \"a\", \"the\", \"and\" many times which just isn't a useful signal. A popular way to deal with this is to scale the term frequencies (i.e., the n-gram counts) by the Inverse Document Frequency (or IDF), producing TF-IDF vectors. In Tribuo the IDF is a transformation which is applied separately to the dataset after it's constructed, as it uses aggregate information from the whole dataset which isn't available until all the examples have been loaded in. Let's see how that affects performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf-idf training data size = 11314, number of features = 1143035, number of classes = 20\n",
      "tf-idf testing data size = 7532, number of features = 316757, number of classes = 20\n"
     ]
    }
   ],
   "source": [
    "// Create a transformation map that contains a single IDFTransformation to apply to every feature\n",
    "var trMap = new TransformationMap(Collections.singletonList(new IDFTransformation()));\n",
    "// Copy out the datasets.\n",
    "var tfidfTrain = MutableDataset.createDeepCopy(bigramPair.getA());\n",
    "var tfidfTest = MutableDataset.createDeepCopy(bigramPair.getB());\n",
    "// Fit the IDF transformation and apply it to the data\n",
    "// We add the implicit zero features (i.e. the words not present in each document)\n",
    "// to get the correct estimate of the IDF.\n",
    "var transformers = tfidfTrain.createTransformers(trMap,true);\n",
    "tfidfTrain.transform(transformers);\n",
    "tfidfTest.transform(transformers);\n",
    "// Print the dataset statistics    \n",
    "System.out.println(String.format(\"tf-idf training data size = %d, number of features = %d, number of classes = %d\",tfidfTrain.size(),tfidfTrain.getFeatureMap().size(),tfidfTrain.getOutputInfo().size()));\n",
    "System.out.println(String.format(\"tf-idf testing data size = %d, number of features = %d, number of classes = %d\",tfidfTest.size(),tfidfTest.getFeatureMap().size(),tfidfTest.getOutputInfo().size()));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Creating TF-IDF vectors didn't change the number of features, we still have 1.1 million features in the training set, but it has made the feature values more useful. The irrelevant \"the\" features will have a small value because while they may have a high term frequency, they are also present in every document so they have a high document frequency, so when we divide the two values it'll end up small."
   ]
  },
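  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of the idea, a common textbook form of the TF-IDF weight for a term $t$ in a document $d$ is shown below. The symbols $\\text{tf}$, $\\text{df}$ and $N$ are notation introduced here for illustration, and Tribuo's `IDFTransformation` may use a variant of this weighting (for example a log-scaled term frequency), so treat this as an illustration rather than the exact formula Tribuo applies.\n",
    "\n",
    "$$\n",
    "\\text{tfidf}(t, d) = \\text{tf}(t, d) \\cdot \\log\\frac{N}{\\text{df}(t)}\n",
    "$$\n",
    "\n",
    "Here $\\text{tf}(t, d)$ is the count of $t$ in $d$, $\\text{df}(t)$ is the number of training documents containing $t$, and $N$ is the number of training documents. With $N = 11314$, a term that appears in every document gets $\\log(11314 / 11314) = 0$, whereas a term that appears in only 10 documents gets $\\log(11314 / 10) \\approx 7.03$ (natural log), so rare terms are weighted far more heavily than ubiquitous ones."
   ]
  },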
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training the model on TF-IDF features took (00:00:45:063)\n",
      "\n",
      "Class                                n          tp          fn          fp      recall        prec          f1\n",
      "soc.religion.christian             398         350          48         183       0.879       0.657       0.752\n",
      "rec.autos                          396         332          64          68       0.838       0.830       0.834\n",
      "talk.religion.misc                 251         155          96         111       0.618       0.583       0.600\n",
      "comp.windows.x                     395         290         105          58       0.734       0.833       0.781\n",
      "rec.sport.baseball                 397         345          52          26       0.869       0.930       0.898\n",
      "comp.graphics                      389         264         125         111       0.679       0.704       0.691\n",
      "talk.politics.mideast              376         306          70          32       0.814       0.905       0.857\n",
      "comp.sys.ibm.pc.hardware           392         285         107         170       0.727       0.626       0.673\n",
      "sci.med                            396         305          91          63       0.770       0.829       0.798\n",
      "comp.os.ms-windows.misc            394         248         146          71       0.629       0.777       0.696\n",
      "sci.crypt                          396         340          56          47       0.859       0.879       0.868\n",
      "comp.sys.mac.hardware              385         283         102          69       0.735       0.804       0.768\n",
      "misc.forsale                       390         340          50          79       0.872       0.811       0.841\n",
      "rec.motorcycles                    398         359          39          36       0.902       0.909       0.905\n",
      "talk.politics.misc                 310         191         119         130       0.616       0.595       0.605\n",
      "sci.electronics                    393         292         101         112       0.743       0.723       0.733\n",
      "rec.sport.hockey                   399         376          23          32       0.942       0.922       0.932\n",
      "sci.space                          394         339          55          52       0.860       0.867       0.864\n",
      "alt.atheism                        319         226          93          57       0.708       0.799       0.751\n",
      "talk.politics.guns                 364         303          61          96       0.832       0.759       0.794\n",
      "Total                            7,532       5,929       1,603       1,603\n",
      "Accuracy                                                                         0.787\n",
      "Micro Average                                                                    0.787       0.787       0.787\n",
      "Macro Average                                                                    0.781       0.787       0.782\n",
      "Balanced Error Rate                                                              0.219\n"
     ]
    }
   ],
   "source": [
    "var tfidfStartTime = System.currentTimeMillis();\n",
    "var tfidfModel = lrTrainer.train(tfidfTrain);\n",
    "var tfidfEndTime = System.currentTimeMillis();\n",
    "System.out.println(\"Training the model on TF-IDF features took \" + Util.formatDuration(tfidfStartTime,tfidfEndTime));\n",
    "System.out.println();\n",
    "var tfidfEval = labelEvaluator.evaluate(tfidfModel,tfidfTest);\n",
    "System.out.println(tfidfEval);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using TF-IDF features gives roughly the same accuracy as bigrams, so it may be that these features aren't something the linear model can easily exploit on this dataset, but in general the TF-IDF transformation is a useful one when working with text documents.\n",
    "\n",
    "## Feature hashing\n",
    "\n",
    "A popular technique for dealing with large feature spaces is feature hashing, where the features are mapped down to a smaller space using a hash function. This induces collisions between the features, so the model might treat \"New York\" and \"San Francisco\" as the same feature. However, the collisions are generated essentially at random by the hash function, which provides a strong regularising effect that can improve performance while making things run faster and use less memory.\n",
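    "\n",
    "As a minimal sketch of the idea (not Tribuo's implementation), the snippet below maps feature names into a fixed number of buckets. `HashingSketch`, `bucket` and `hash` are hypothetical names, and `String.hashCode` is only a stand-in for whatever hash function a real system would use.\n",
    "\n",
    "```java\n",
    "import java.util.HashMap;\n",
    "import java.util.List;\n",
    "import java.util.Map;\n",
    "\n",
    "public class HashingSketch {\n",
    "    // Map a feature name to a bucket in [0, dimension).\n",
    "    public static int bucket(String feature, int dimension) {\n",
    "        return Math.floorMod(feature.hashCode(), dimension);\n",
    "    }\n",
    "\n",
    "    // Sum the counts of all features that land in the same bucket.\n",
    "    public static Map<Integer, Double> hash(List<String> features, int dimension) {\n",
    "        Map<Integer, Double> hashed = new HashMap<>();\n",
    "        for (String f : features) {\n",
    "            hashed.merge(bucket(f, dimension), 1.0, Double::sum);\n",
    "        }\n",
    "        return hashed;\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        var features = List.of(\"new\", \"york\", \"new/york\", \"san\", \"francisco\", \"san/francisco\");\n",
    "        // With a tiny dimension collisions are guaranteed: six features, four buckets.\n",
    "        System.out.println(hash(features, 4));\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "The model then learns one weight per bucket rather than one per distinct feature, which is where the memory saving comes from.\n",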
    "\n",
    "To use feature hashing in Tribuo, simply pass a hash dimension to the `TokenPipeline` on construction. We'll map everything down to 50,000 features, which is around 5% of the original number, and see how that affects the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hash-50k training data size = 11314, number of features = 50000, number of classes = 20\n",
      "hash-50k testing data size = 7532, number of features = 50000, number of classes = 20\n"
     ]
    }
   ],
   "source": [
    "var hashPipeline = new TokenPipeline(tokenizer, 2, true, 50000);\n",
    "var hashExtractor = new TextFeatureExtractorImpl<Label>(hashPipeline);\n",
    "var hashPair = mkDatasets(\"hash-50k\",hashExtractor);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, we still have the same number of training and test examples, but now there are only 50,000 features. Let's build another logistic regression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training the model on hashed features took (00:00:23:354)\n",
      "\n",
      "Class                                n          tp          fn          fp      recall        prec          f1\n",
      "soc.religion.christian             398         306          92         125       0.769       0.710       0.738\n",
      "rec.autos                          396         324          72          77       0.818       0.808       0.813\n",
      "talk.religion.misc                 251         139         112         132       0.554       0.513       0.533\n",
      "comp.windows.x                     395         273         122          78       0.691       0.778       0.732\n",
      "rec.sport.baseball                 397         335          62          64       0.844       0.840       0.842\n",
      "comp.graphics                      389         238         151         135       0.612       0.638       0.625\n",
      "talk.politics.mideast              376         265         111          35       0.705       0.883       0.784\n",
      "comp.sys.ibm.pc.hardware           392         276         116         178       0.704       0.608       0.652\n",
      "sci.med                            396         251         145         125       0.634       0.668       0.650\n",
      "comp.os.ms-windows.misc            394         254         140         109       0.645       0.700       0.671\n",
      "sci.crypt                          396         305          91          36       0.770       0.894       0.828\n",
      "comp.sys.mac.hardware              385         259         126          97       0.673       0.728       0.699\n",
      "misc.forsale                       390         325          65          87       0.833       0.789       0.810\n",
      "rec.motorcycles                    398         341          57          75       0.857       0.820       0.838\n",
      "talk.politics.misc                 310         171         139         195       0.552       0.467       0.506\n",
      "sci.electronics                    393         243         150         159       0.618       0.604       0.611\n",
      "rec.sport.hockey                   399         353          46          59       0.885       0.857       0.871\n",
      "sci.space                          394         305          89          49       0.774       0.862       0.816\n",
      "alt.atheism                        319         215         104         100       0.674       0.683       0.678\n",
      "talk.politics.guns                 364         292          72         147       0.802       0.665       0.727\n",
      "Total                            7,532       5,470       2,062       2,062\n",
      "Accuracy                                                                         0.726\n",
      "Micro Average                                                                    0.726       0.726       0.726\n",
      "Macro Average                                                                    0.721       0.726       0.721\n",
      "Balanced Error Rate                                                              0.279\n"
     ]
    }
   ],
   "source": [
    "var hashStartTime = System.currentTimeMillis();\n",
    "var hashModel = lrTrainer.train(hashPair.getA());\n",
    "var hashEndTime = System.currentTimeMillis();\n",
    "System.out.println(\"Training the model on hashed features took \" + Util.formatDuration(hashStartTime,hashEndTime));\n",
    "System.out.println();\n",
    "var hashEval = labelEvaluator.evaluate(hashModel,hashPair.getB());\n",
    "System.out.println(hashEval);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Accuracy dropped to 72.6% here, around six points below the TF-IDF model, but the hashed model has fewer than a tenth of the parameters of the bigram model, making it faster and much smaller at inference time, and it took around 66% of the time to train. In many cases dropping a couple of points of accuracy for a model that is 20x smaller and substantially faster is a worthwhile tradeoff, but as with most machine learning tasks this depends on the problem you're solving and where you're deploying the model. Tuning the hashing dimension and the trainer parameters will likely produce a model with accuracy similar to the unhashed model at greatly reduced computational cost.\n",
    "\n",
    "## Trimming out infrequent features\n",
    "\n",
    "We can also directly trim out infrequently occurring features. If a feature doesn't occur very frequently then we're not likely to estimate its weights properly, as we've not seen it very often. Then if it occurs frequently in the test dataset it can confuse the model (this is a form of overfitting to the training data). Let's take the TF-IDF dataset and remove all the features (unigrams and bigrams) that occur fewer than 5 times."
   ]
  },
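  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea behind trimming can be sketched in a few lines of plain Java. This is an illustration only: `TrimSketch` and `frequentFeatures` are hypothetical names, and the sketch counts the number of documents a feature appears in, which may not match exactly how Tribuo counts occurrences.\n",
    "\n",
    "```java\n",
    "import java.util.HashMap;\n",
    "import java.util.HashSet;\n",
    "import java.util.List;\n",
    "import java.util.Map;\n",
    "import java.util.Set;\n",
    "import java.util.TreeSet;\n",
    "\n",
    "public class TrimSketch {\n",
    "    // Return the features that appear in at least minCount documents.\n",
    "    public static Set<String> frequentFeatures(List<List<String>> docs, int minCount) {\n",
    "        Map<String, Integer> counts = new HashMap<>();\n",
    "        for (List<String> doc : docs) {\n",
    "            // Count each feature at most once per document.\n",
    "            for (String f : new HashSet<>(doc)) {\n",
    "                counts.merge(f, 1, Integer::sum);\n",
    "            }\n",
    "        }\n",
    "        Set<String> kept = new TreeSet<>();\n",
    "        for (var e : counts.entrySet()) {\n",
    "            if (e.getValue() >= minCount) {\n",
    "                kept.add(e.getKey());\n",
    "            }\n",
    "        }\n",
    "        return kept;\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        var docs = List.of(List.of(\"the\", \"cat\"), List.of(\"the\", \"dog\"), List.of(\"the\", \"cat\", \"sat\"));\n",
    "        System.out.println(frequentFeatures(docs, 2));\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "With a threshold of 2, only \"the\" and \"cat\" survive in this toy corpus, while \"dog\" and \"sat\" are dropped because each appears in a single document."
   ]
  },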
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Minimum cardinality training data size = 11314, number of features = 109743, number of classes = 20\n",
      "Minimum cardinality testing data size = 7532, number of features = 109743, number of classes = 20\n"
     ]
    }
   ],
   "source": [
    "var minCardTrain = new MinimumCardinalityDataset<>(tfidfTrain,5);\n",
    "// This call creates a copy of tfidfTest, removing all the\n",
    "// features not found in minCardTrain's feature and output maps\n",
    "var minCardTest = ImmutableDataset.copyDataset(tfidfTest,minCardTrain.getFeatureIDMap(),minCardTrain.getOutputIDInfo());\n",
    "// Print the dataset statistics\n",
    "System.out.println(String.format(\"Minimum cardinality training data size = %d, number of features = %d, number of classes = %d\",minCardTrain.size(),minCardTrain.getFeatureMap().size(),minCardTrain.getOutputInfo().size()));\n",
    "System.out.println(String.format(\"Minimum cardinality testing data size = %d, number of features = %d, number of classes = %d\",minCardTest.size(),minCardTest.getFeatureMap().size(),minCardTest.getOutputInfo().size()));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that's removed about 90% of the features, so let's try our simple model on it again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training the model on trimmed TF-IDF features took (00:00:19:928)\n",
      "\n",
      "Class                                n          tp          fn          fp      recall        prec          f1\n",
      "soc.religion.christian             398         337          61          93       0.847       0.784       0.814\n",
      "rec.autos                          396         312          84          60       0.788       0.839       0.813\n",
      "talk.religion.misc                 251         172          79         143       0.685       0.546       0.608\n",
      "comp.windows.x                     395         290         105          56       0.734       0.838       0.783\n",
      "rec.sport.baseball                 397         344          53          37       0.866       0.903       0.884\n",
      "comp.graphics                      389         284         105         112       0.730       0.717       0.724\n",
      "talk.politics.mideast              376         301          75          19       0.801       0.941       0.865\n",
      "comp.sys.ibm.pc.hardware           392         286         106         217       0.730       0.569       0.639\n",
      "sci.med                            396         295         101          74       0.745       0.799       0.771\n",
      "comp.os.ms-windows.misc            394         219         175          52       0.556       0.808       0.659\n",
      "sci.crypt                          396         322          74          49       0.813       0.868       0.840\n",
      "comp.sys.mac.hardware              385         287          98         125       0.745       0.697       0.720\n",
      "misc.forsale                       390         320          70          62       0.821       0.838       0.829\n",
      "rec.motorcycles                    398         353          45          45       0.887       0.887       0.887\n",
      "talk.politics.misc                 310         191         119         131       0.616       0.593       0.604\n",
      "sci.electronics                    393         298          95         148       0.758       0.668       0.710\n",
      "rec.sport.hockey                   399         370          29          41       0.927       0.900       0.914\n",
      "sci.space                          394         336          58          64       0.853       0.840       0.846\n",
      "alt.atheism                        319         218         101          59       0.683       0.787       0.732\n",
      "talk.politics.guns                 364         302          62         108       0.830       0.737       0.780\n",
      "Total                            7,532       5,837       1,695       1,695\n",
      "Accuracy                                                                         0.775\n",
      "Micro Average                                                                    0.775       0.775       0.775\n",
      "Macro Average                                                                    0.771       0.778       0.771\n",
      "Balanced Error Rate                                                              0.229\n"
     ]
    }
   ],
   "source": [
    "var minCardStartTime = System.currentTimeMillis();\n",
    "var minCardModel = lrTrainer.train(minCardTrain);\n",
    "var minCardEndTime = System.currentTimeMillis();\n",
    "System.out.println(\"Training the model on trimmed TF-IDF features took \" + Util.formatDuration(minCardStartTime,minCardEndTime));\n",
    "System.out.println();\n",
    "var minCardEval = labelEvaluator.evaluate(minCardModel,minCardTest);\n",
    "System.out.println(minCardEval);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As with the feature hashing above, this model trains more quickly because there is less data to process, but the speed improvement is more substantial as the number of features in each example is lower (because hashing produces a denser example than trimming out infrequent features). Accuracy dropped slightly compared to the TF-IDF model (77.5% versus 78.7%), but again the model has around 10% of the parameters, with a corresponding reduction in memory and runtime in inference and training. Accuracy is better than with hashing because we're not colliding features at random, we're simply removing ones which are infrequent. If a feature is infrequent we probably can't estimate its weight very well, so removing it helps remove some of the noise.\n",
    "\n",
    "Choosing between feature hashing and trimming out infrequent features is problem dependent. Feature hashing can work in denser feature spaces than trimming infrequent features, but both still require some amount of sparsity in the problem to have any useful effect. With text datasets, trimming the infrequent words/features is usually helpful."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Word embeddings\n",
    "\n",
    "None of the approaches described above has any notion of word similarity. They rely upon exactly the same words with the same spelling appearing in the training and test documents, when in practice word similarity is likely to be very useful information for the classifier because no two documents use exactly the same phrasing. For example, the unigrams \"excellent\" and \"fantastic\" are as dissimilar as any other pair of words to an n-gram model, when in fact those words are quite similar in meaning. Adding notions of word similarity to ML models usually means embedding each word into some vector space, so that words with similar meanings are close in the vector space, and words with dissimilar or opposite meanings are far apart. There are many popular word embedding algorithms, like [Word2Vec](https://arxiv.org/abs/1301.3781), [GloVe](https://nlp.stanford.edu/projects/glove/) or [FastText](https://fasttext.cc/), which build embeddings on a corpus of text that can then be used in downstream tasks. Tribuo doesn't have a class which can directly load those word vectors, as they all come in different file formats, but it's pretty straightforward to build a `TextFeatureExtractor` that will tokenize the input text, look up each word or n-gram in the vector space and then average them across the input (it took us about an afternoon to build one for our internal word2vec style word vector research file format). If there is interest from the community in supporting a specific word vector file format, we're happy to accept PRs that add the support.\n",
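    "\n",
    "The lookup-and-average step can be sketched as follows. This is a plain Java illustration rather than a Tribuo `TextFeatureExtractor`: the `AverageEmbeddingSketch` class, its `average` method and the toy two-dimensional vectors are all invented for this example.\n",
    "\n",
    "```java\n",
    "import java.util.Arrays;\n",
    "import java.util.List;\n",
    "import java.util.Map;\n",
    "\n",
    "public class AverageEmbeddingSketch {\n",
    "    // Average the vectors of all tokens found in the lookup table.\n",
    "    // Tokens without a vector are skipped; if none are found the result is all zeros.\n",
    "    public static double[] average(List<String> tokens, Map<String, double[]> vectors, int dim) {\n",
    "        double[] sum = new double[dim];\n",
    "        int found = 0;\n",
    "        for (String t : tokens) {\n",
    "            double[] v = vectors.get(t);\n",
    "            if (v != null) {\n",
    "                for (int i = 0; i < dim; i++) {\n",
    "                    sum[i] += v[i];\n",
    "                }\n",
    "                found++;\n",
    "            }\n",
    "        }\n",
    "        if (found > 0) {\n",
    "            for (int i = 0; i < dim; i++) {\n",
    "                sum[i] /= found;\n",
    "            }\n",
    "        }\n",
    "        return sum;\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        // Toy vectors, not taken from any real embedding model.\n",
    "        Map<String, double[]> vectors = Map.of(\n",
    "            \"excellent\", new double[]{1.0, 0.0},\n",
    "            \"fantastic\", new double[]{0.0, 1.0});\n",
    "        double[] doc = average(List.of(\"an\", \"excellent\", \"fantastic\", \"film\"), vectors, 2);\n",
    "        System.out.println(Arrays.toString(doc));\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "Each element of the averaged vector would then become one feature of the example, so the feature space has the dimension of the embedding rather than the size of the vocabulary.\n",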
    "\n",
    "While these more traditional forms of word vector are very powerful, as they are precomputed they treat each word the same no matter the context it appears in. For example \"bank\" could mean a river bank, or a financial institution, but a word2vec vector has to contain both meanings because it doesn't know the *context* the word is present in, i.e., the rest of the sentence. This led to the rise of *contextual* word embeddings, which produce a vector for each word based on the whole input sequence. The most popular of these embeddings are based on the [Transformer](https://papers.nips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html) architecture, usually a variant of Google's [BERT](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html) model.\n",
    "\n",
    "## Using BERT embeddings\n",
    "\n",
    "BERT is a multi-layer transformer network, which reads in a sentence and produces both an embedding of the sentence and embeddings for each wordpiece. A \"wordpiece\" is the token that BERT operates on, which is either a whole word, or a chunk of a word, emitted by the wordpiece tokenizer. The word chunking algorithm is trained on a large corpus and allows common prefixes and suffixes (e.g. \"un\", \"ing\") to be split off the words and to share state. We can use BERT to produce a single vector which represents the sentence or document and then use that vector as features in a downstream Tribuo classifier.\n",
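    "\n",
    "To make the word chunking concrete, here is a minimal sketch of greedy longest-match-first splitting, the lookup strategy commonly described for wordpiece vocabularies. It is not Tribuo's tokenizer: `WordpieceSketch`, `split` and the five-entry vocabulary are hypothetical, and a real tokenizer also handles casing, punctuation and maximum word lengths.\n",
    "\n",
    "```java\n",
    "import java.util.ArrayList;\n",
    "import java.util.List;\n",
    "import java.util.Set;\n",
    "\n",
    "public class WordpieceSketch {\n",
    "    // Split one word into pieces, always taking the longest vocabulary match.\n",
    "    // Pieces that continue a word are looked up with a \"##\" prefix.\n",
    "    public static List<String> split(String word, Set<String> vocab) {\n",
    "        List<String> pieces = new ArrayList<>();\n",
    "        int start = 0;\n",
    "        while (start < word.length()) {\n",
    "            int end = word.length();\n",
    "            String match = null;\n",
    "            while (end > start) {\n",
    "                String candidate = (start > 0 ? \"##\" : \"\") + word.substring(start, end);\n",
    "                if (vocab.contains(candidate)) {\n",
    "                    match = candidate;\n",
    "                    break;\n",
    "                }\n",
    "                end--;\n",
    "            }\n",
    "            if (match == null) {\n",
    "                // No piece matched at this position, so the whole word is unknown.\n",
    "                return List.of(\"[UNK]\");\n",
    "            }\n",
    "            pieces.add(match);\n",
    "            start = end;\n",
    "        }\n",
    "        return pieces;\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        Set<String> vocab = Set.of(\"un\", \"play\", \"##play\", \"##ing\", \"##able\");\n",
    "        System.out.println(split(\"unplaying\", vocab));\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "With this toy vocabulary the word \"unplaying\" is split into `un`, `##play` and `##ing`, so the prefix and suffix pieces can be shared with every other word that contains them.\n",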
    "\n",
    "Tribuo works with BERT models that are stored in [ONNX format](https://onnx.ai), and can load in tokenizers produced by [HuggingFace Transformers](https://huggingface.co/transformers/). That package also helpfully provides a Python script to convert BERT models from HuggingFace format into ONNX format for deployment. We provide a `TextFeatureExtractor` implementation called `BERTFeatureExtractor` which can produce sentence embeddings by passing the text through a BERT model. Tribuo uses Microsoft's [ONNX Runtime](https://www.onnxruntime.ai/) to load the model, and has its own implementation of the Wordpiece tokenization algorithm, along with the necessary glue to produce tokens in the format that BERT expects. One downside of BERT models is that they have a maximum document length that they can process, usually 512 wordpieces. This is configurable in Tribuo's extractor, but if you set the maximum length to be longer than the sequences the model was trained on then the performance is likely to suffer (or the computation may fail depending on how that specific BERT model is implemented).\n",
    "\n",
    "To follow along with this part of the tutorial you'll need to produce a BERT model in ONNX format. To do that you'll need access to a Python 3 environment with HuggingFace Transformers and PyTorch or TensorFlow installed to export the model (the snippet below assumes PyTorch, change the `pt` to `tf` if you're using TensorFlow). Running the following snippet will produce a `bert-base-uncased.onnx` file that we can use for the rest of the tutorial. You'll need to run it in an empty directory due to the way HuggingFace's conversion script works.\n",
    "\n",
    "```\n",
    "python -m transformers.convert_graph_to_onnx --framework pt --model bert-base-uncased bert-base-uncased.onnx\n",
    "```\n",
    "\n",
    "You'll also need to download the `tokenizer.json` that goes with the BERT variant you are using. For `bert-base-uncased` that file is [here](https://huggingface.co/bert-base-uncased/blob/main/tokenizer.json). Assuming both of those files are now in the same directory as this tutorial, we can create the `BERTFeatureExtractor`. We're going to take the average token embedding across the whole input, as the `[CLS]` token which provides the sentence embedding tends to perform poorly unless it is fine-tuned on your task.\n",
    "\n",
    "Warning: this feature extraction step took more than a minute per newsgroup on a 2019 16\" 6-core MacBook Pro (using the default settings of ONNX Runtime, i.e., using a single thread on the CPU provider) so around 55 minutes to extract the full train and test datasets. Your mileage may vary, and your laptop may get quite warm. We recommend not running it while your laptop is actually on your lap. At the moment Tribuo's `TextFeatureExtractor` interface doesn't batch up the inputs, which limits the performance of contextual feature extractors. We'll look at expanding that interface to support batching in a future release. The session options used can be controlled by the `BERTFeatureExtractor.reconfigureOrtSession(SessionOptions options)` method, which allows the use of whatever configuration is supported by your onnxruntime jar."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bert training data size = 11314, number of features = 768, number of classes = 20\n",
      "bert testing data size = 7532, number of features = 768, number of classes = 20\n",
      "Extracting features with BERT took (00:38:37:476)\n"
     ]
    }
   ],
   "source": [
    "var bertPath = Paths.get(\"./bert-base-uncased.onnx\");\n",
    "var tokenizerPath = Paths.get(\"./tokenizer.json\");\n",
    "var bert = new BERTFeatureExtractor<>(labelFactory,\n",
    "                                      bertPath,\n",
    "                                      tokenizerPath,\n",
    "                                      BERTFeatureExtractor.OutputPooling.MEAN,\n",
    "                                      256,  // Maximum number of wordpiece tokens\n",
    "                                      false // Use Nvidia GPUs for inference (if onnxruntime_gpu is on the classpath)\n",
    "                                      );\n",
    "                                      \n",
    "var bertStartTime = System.currentTimeMillis();\n",
    "var bertPair = mkDatasets(\"bert\",bert);\n",
    "var bertEndTime = System.currentTimeMillis();\n",
    "System.out.println(\"Extracting features with BERT took \" + Util.formatDuration(bertStartTime,bertEndTime));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that Tribuo's `BERTFeatureExtractor` can run the BERT embeddings on a GPU, but only if the onnxruntime_gpu jar is on the classpath. By default Tribuo pulls in the CPU only jar for maximum compatibility. As you can see from the time taken to extract the features, it's best to deploy BERT when you've got plenty of CPUs or fast GPUs.\n",
    "\n",
    "Now we build a logistic regression on the dense feature space produced by BERT. These embeddings are dense 768-dimensional vectors, so each document contains a value for every one of those dimensions. In Tribuo 4.1 we added optimisations to several of the models and trainers to improve their performance on the dense feature spaces produced by techniques like BERT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training a LR on BERT features took (00:00:06:082)\n",
      "Class                                n          tp          fn          fp      recall        prec          f1\n",
      "soc.religion.christian             398         353          45         111       0.887       0.761       0.819\n",
      "rec.autos                          396         332          64          99       0.838       0.770       0.803\n",
      "talk.religion.misc                 251         102         149         131       0.406       0.438       0.421\n",
      "comp.windows.x                     395         288         107         121       0.729       0.704       0.716\n",
      "rec.sport.baseball                 397         365          32          32       0.919       0.919       0.919\n",
      "comp.graphics                      389         257         132         183       0.661       0.584       0.620\n",
      "talk.politics.mideast              376         289          87          26       0.769       0.917       0.836\n",
      "comp.sys.ibm.pc.hardware           392         220         172         166       0.561       0.570       0.566\n",
      "sci.med                            396         320          76          34       0.808       0.904       0.853\n",
      "comp.os.ms-windows.misc            394         247         147         187       0.627       0.569       0.597\n",
      "sci.crypt                          396         314          82          95       0.793       0.768       0.780\n",
      "comp.sys.mac.hardware              385         134         251          32       0.348       0.807       0.486\n",
      "misc.forsale                       390         342          48         103       0.877       0.769       0.819\n",
      "rec.motorcycles                    398         308          90          75       0.774       0.804       0.789\n",
      "talk.politics.misc                 310         186         124         226       0.600       0.451       0.515\n",
      "sci.electronics                    393         252         141         197       0.641       0.561       0.599\n",
      "rec.sport.hockey                   399         381          18          21       0.955       0.948       0.951\n",
      "sci.space                          394         332          62          78       0.843       0.810       0.826\n",
      "alt.atheism                        319         163         156         121       0.511       0.574       0.541\n",
      "talk.politics.guns                 364         210         154          99       0.577       0.680       0.624\n",
      "Total                            7,532       5,395       2,137       2,137\n",
      "Accuracy                                                                         0.716\n",
      "Micro Average                                                                    0.716       0.716       0.716\n",
      "Macro Average                                                                    0.706       0.715       0.704\n",
      "Balanced Error Rate                                                              0.294\n"
     ]
    }
   ],
   "source": [
    "var bertStartTime = System.currentTimeMillis();\n",
    "var bertModel = lrTrainer.train(bertPair.getA());\n",
    "var bertEndTime = System.currentTimeMillis();\n",
    "System.out.println(\"Training a LR on BERT features took \" + Util.formatDuration(bertStartTime,bertEndTime));\n",
    "var bertEval = labelEvaluator.evaluate(bertModel,bertPair.getB());\n",
    "System.out.println(bertEval);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We get around 72% accuracy using this standard BERT model, which is lower than the TF-IDF features. This might be due to its training data of Wikipedia and books not overlapping well with the comparatively old newsgroup language. Fine-tuning the BERT model on a large corpus of newsgroups could probably improve this, but the standard model is likely to work well for more well-formed text like news articles or more formal documents. Alternatively it may be that the logistic regression we're training isn't sufficiently flexible to use the information in the BERT features, so it may be beneficial to use a more complex classifier like gradient boosted trees or a Multi-Layer Perceptron through Tribuo's TensorFlow interface.\n",
    "\n",
    "Using different BERT versions can change the accuracy, as there are variants fine-tuned for a wide variety of different tasks and domains, and there are smaller versions like DistilBERT and TinyBERT which are useful for deploying models in constrained environments. However, BERT-based feature extractors will always be slower than the simpler BoW approaches described above, because they have to perform lots of floating point computations to compute the embedded feature values.\n",
    "\n",
    "# Deploying the feature extractors\n",
    "\n",
    "Similarly to when working with columnar data, the feature extractor used is recorded in the model provenance. We can see that for the BERT model here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DirectoryFileSource(\n",
      "\tclass-name = org.tribuo.data.text.DirectoryFileSource\n",
      "\tdataDir = /Users/apocock/Development/Tribuo/tutorials/20news/20news-bydate-train\n",
      "\tpreprocessors = List[\n",
      "\t\tNewsPreprocessor(\n",
      "\t\t\t\t\tclass-name = org.tribuo.data.text.impl.NewsPreprocessor\n",
      "\t\t\t\t\thost-short-name = DocumentPreprocessor\n",
      "\t\t\t\t)\n",
      "\t\tCasingPreprocessor(\n",
      "\t\t\t\t\tclass-name = org.tribuo.data.text.impl.CasingPreprocessor\n",
      "\t\t\t\t\top = LOWERCASE\n",
      "\t\t\t\t\thost-short-name = DocumentPreprocessor\n",
      "\t\t\t\t)\n",
      "\t]\n",
      "\textractor = BERTFeatureExtractor(\n",
      "\t\t\tclass-name = org.tribuo.interop.onnx.extractors.BERTFeatureExtractor\n",
      "\t\t\tuseCUDA = false\n",
      "\t\t\tpooling = MEAN\n",
      "\t\t\tmodelPath = /Users/apocock/Development/Tribuo/tutorials/bert-base-uncased.onnx\n",
      "\t\t\ttokenizerPath = /Users/apocock/Development/Tribuo/tutorials/tokenizer.json\n",
      "\t\t\toutputFactory = LabelFactory(\n",
      "\t\t\t\t\tclass-name = org.tribuo.classification.LabelFactory\n",
      "\t\t\t\t)\n",
      "\t\t\tmaxLength = 256\n",
      "\t\t\thost-short-name = FeatureExtractor\n",
      "\t\t)\n",
      "\toutputFactory = LabelFactory(\n",
      "\t\t\tclass-name = org.tribuo.classification.LabelFactory\n",
      "\t\t)\n",
      "\tfile-modified-time = 2003-03-18T07:24:55-05:00\n",
      "\tdatasource-creation-time = 2021-12-18T20:50:57.169758-05:00\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "var sourceProvenance = bertModel.getProvenance().getDatasetProvenance().getSourceProvenance();\n",
    "System.out.println(ProvenanceUtil.formattedProvenanceString(sourceProvenance));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This means that the model has recorded how the features were extracted, but the extraction process itself isn't part of the serialized model (which we wouldn't really want anyway as BERT models are hundreds of megabytes). So to use one of these models at inference time the feature extraction pipeline needs to be rebuilt from the configuration, in the same way we rebuilt the `RowProcessor` in the columnar tutorial.\n",
    "\n",
    "Each of the different models trained in this tutorial has recorded the source provenance and its associated `TextFeatureExtractor` configuration, meaning the models come with all the necessary information to infer the classes of new documents.\n",
    "\n",
    "# Conclusion\n",
    "\n",
    "We looked at a document classification task in Tribuo. As most of the work in NLP tends to be on featurising the data, we discussed several different ways of converting text into features for use in machine learning. We looked at Bag of Words models, using n-grams, term frequencies, TF-IDF vectors and feature hashing, and also looked at trimming large feature spaces based on the number of times we'd seen a feature. We also discussed word vector approaches, and showed how to use the popular contextual word embedding model, BERT, to extract features for document classification. It's worth noting all the models trained were simple logistic regressions, with no parameter tuning. Using a more powerful classifier like XGBoost, or performing hyperparameter tuning on the logistic regression, will likely improve performance over the simple baselines presented here.\n",
    "\n",
    "Tribuo's text processing framework is very flexible, and it's possible to insert your own code into each of the different classes by implementing `TextFeatureExtractor`, `TextPipeline` or even the `Tokenizer` yourself, while the provenance system ensures that you can always recover how your data was processed to ensure it matches at inference time."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "17+35-LTS-2724"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
